{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import theano\n",
    "import theano.tensor as T\n",
    "from theano import shared \n",
    "from collections import OrderedDict\n",
    "\n",
    "dtype=T.config.floatX\n",
    "theano.config.optimizer='fast_compile'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "#Embedded Reber Grammar http://christianherta.de/lehre/dataScience/machineLearning/neuralNetworks/reberGrammar.php\n",
    "\n",
    "chars='BTSXPVE'\n",
    "\n",
    "graph = [[(1,5),('T','P')] , [(1,2),('S','X')], \\\n",
    "           [(3,5),('S','X')], [(6,),('E')], \\\n",
    "           [(3,2),('V','P')], [(4,5),('V','T')] ]\n",
    "\n",
    "\n",
    "def in_grammar(word):\n",
    "    if word[0] != 'B':\n",
    "        return False\n",
    "    node = 0    \n",
    "    for c in word[1:]:\n",
    "        transitions = graph[node]\n",
    "        try:\n",
    "            node = transitions[0][transitions[1].index(c)]\n",
    "        except ValueError: # using exceptions for flow control in python is common\n",
    "            return False\n",
    "    return True        \n",
    "      \n",
    "def sequenceToWord(sequence):\n",
    "    \"\"\"\n",
    "    converts a sequence (one-hot) in a reber string\n",
    "    \"\"\"\n",
    "    reberString = ''\n",
    "    for s in sequence:\n",
    "        index = np.where(s==1.)[0][0]\n",
    "        reberString += chars[index]\n",
    "    return reberString\n",
    "    \n",
    "def generateSequences(minLength):\n",
    "    while True:\n",
    "        inchars = ['B']\n",
    "        node = 0\n",
    "        outchars = []    \n",
    "        while node != 6:\n",
    "            transitions = graph[node]\n",
    "            i = np.random.randint(0, len(transitions[0]))\n",
    "            inchars.append(transitions[1][i])\n",
    "            outchars.append(transitions[1])\n",
    "            node = transitions[0][i]\n",
    "        if len(inchars) > minLength:  \n",
    "            return inchars, outchars\n",
    "\n",
    "\n",
    "def get_one_example(minLength):\n",
    "    inchars, outchars = generateSequences(minLength)\n",
    "    inseq = []\n",
    "    outseq= []\n",
    "    for i,o in zip(inchars, outchars): \n",
    "        inpt = np.zeros(7)\n",
    "        inpt[chars.find(i)] = 1.     \n",
    "        outpt = np.zeros(7)\n",
    "        for oo in o:\n",
    "            outpt[chars.find(oo)] = 1.\n",
    "        inseq.append(inpt)\n",
    "        outseq.append(outpt)\n",
    "    return inseq, outseq\n",
    "\n",
    "\n",
    "def get_char_one_hot(char):\n",
    "    char_oh = np.zeros(7)\n",
    "    for c in char:\n",
    "        char_oh[chars.find(c)] = 1.\n",
    "    return [char_oh] \n",
    "    \n",
    "def get_n_examples(n, minLength=10):\n",
    "    examples = []\n",
    "    for i in xrange(n):\n",
    "        examples.append(get_one_example(minLength))\n",
    "    return examples\n",
    "\n",
    "emb_chars = \"TP\"\n",
    "\n",
    "\n",
    "def get_one_embedded_example(minLength=10):\n",
    "    i, o = get_one_example(minLength)\n",
    "    emb_char = emb_chars[np.random.randint(0, len(emb_chars))]\n",
    "    new_in = get_char_one_hot(('B',))\n",
    "    new_in += get_char_one_hot((emb_char,))\n",
    "    new_out= get_char_one_hot(emb_chars)\n",
    "    new_out+= get_char_one_hot('B',)\n",
    "    new_in += i\n",
    "    new_out += o\n",
    "    new_in += get_char_one_hot(('E',))\n",
    "    new_in += get_char_one_hot((emb_char,))\n",
    "    new_out += get_char_one_hot((emb_char, ))\n",
    "    new_out += get_char_one_hot(('E',))\n",
    "    return new_in, new_out\n",
    "    \n",
    "def get_n_embedded_examples(n, minLength=10):\n",
    "    examples = []\n",
    "    for i in xrange(n):\n",
    "        examples.append(get_one_embedded_example(minLength))\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "train_data = get_n_embedded_examples(1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "def sample_weights(sizeX, sizeY):\n",
    "    W = np.random.uniform(low=-1., high=1., size=(sizeX, sizeY))\n",
    "    _, svs, _ = np.linalg.svd(W)\n",
    "    values = np.asarray(W / svs[0], dtype=dtype)\n",
    "    return shared(values, borrow=True) \n",
    "\n",
    "class LSTM:\n",
    "    def __init__(self, n_in, n_lstm, n_out):        \n",
    "        self.n_in = n_in\n",
    "        self.n_lstm = n_lstm\n",
    "        self.n_out = n_out\n",
    "        self.W_xi = sample_weights(n_in, n_lstm)\n",
    "        self.W_hi = sample_weights(n_lstm, n_lstm)\n",
    "        self.W_ci = sample_weights(n_lstm, n_lstm)\n",
    "        self.b_i = shared(np.cast[dtype](np.random.uniform(-0.5,.5,size = n_lstm)))\n",
    "        self.W_xf = sample_weights(n_in, n_lstm)\n",
    "        self.W_hf = sample_weights(n_lstm, n_lstm)\n",
    "        self.W_cf = sample_weights(n_lstm, n_lstm)\n",
    "        self.b_f = shared(np.cast[dtype](np.random.uniform(0, 1.,size = n_lstm)))\n",
    "        self.W_xc = sample_weights(n_in, n_lstm)\n",
    "        self.W_hc = sample_weights(n_lstm, n_lstm)\n",
    "        self.b_c = shared(np.zeros(n_lstm, dtype=dtype))\n",
    "        self.W_xo = sample_weights(n_in, n_lstm)\n",
    "        self.W_ho = sample_weights(n_lstm, n_lstm)\n",
    "        self.W_co = sample_weights(n_lstm, n_lstm)\n",
    "        self.b_o = shared(np.cast[dtype](np.random.uniform(-0.5,.5,size = n_lstm)))\n",
    "        self.W_hy = sample_weights(n_lstm, n_out)\n",
    "        self.b_y = shared(np.zeros(n_out, dtype=dtype))\n",
    "        self.params = [self.W_xi, self.W_hi, self.W_ci, self.b_i, \n",
    "                       self.W_xf, self.W_hf, self.W_cf, self.b_f, \n",
    "                       self.W_xc, self.W_hc, self.b_c, \n",
    "                       self.W_ho, self.W_co, self.W_co, self.b_o, \n",
    "                       self.W_hy, self.b_y]\n",
    "                \n",
    "\n",
    "        def step_lstm(x_t, h_tm1, c_tm1):\n",
    "            i_t = T.nnet.sigmoid(T.dot(x_t, self.W_xi) + T.dot(h_tm1, self.W_hi) + T.dot(c_tm1, self.W_ci) + self.b_i)\n",
    "            f_t = T.nnet.sigmoid(T.dot(x_t, self.W_xf) + T.dot(h_tm1, self.W_hf) + T.dot(c_tm1, self.W_cf) + self.b_f)\n",
    "            c_t = f_t * c_tm1 + i_t * T.tanh(T.dot(x_t, self.W_xc) + T.dot(h_tm1, self.W_hc) + self.b_c) \n",
    "            o_t = T.nnet.sigmoid(T.dot(x_t, self.W_xo)+ T.dot(h_tm1, self.W_ho) + T.dot(c_t, self.W_co)  + self.b_o)\n",
    "            h_t = o_t * T.tanh(c_t)\n",
    "            y_t = T.nnet.sigmoid(T.dot(h_t, self.W_hy) + self.b_y) \n",
    "            return [h_t, c_t, y_t]\n",
    "        \n",
    "        X = T.matrix() # X is a sequence of vector   \n",
    "        Y = T.matrix() # Y is a sequence of vector\n",
    "        h0 = shared(np.zeros(self.n_lstm, dtype=dtype)) # initial hidden state \n",
    "        c0 = shared(np.zeros(self.n_lstm, dtype=dtype)) # initial cell state\n",
    "        \n",
    "        [h_vals, c_vals, y_vals], _ = theano.scan(fn=step_lstm,                                  \n",
    "                                                  sequences=X,\n",
    "                                                  outputs_info=[h0, c0, None])\n",
    "        \n",
    "        self.output = y_vals\n",
    "    \n",
    "        cost = -T.mean(Y * T.log(y_vals)+ (1.- Y) * T.log(1. - y_vals))\n",
    "        lr = shared(np.cast[dtype](0.1))\n",
    "        gparams = T.grad(cost, self.params)\n",
    "        updates = OrderedDict()\n",
    "        for param, gparam in zip(self.params, gparams):\n",
    "            updates[param] = param - gparam * lr\n",
    "        self.train = theano.function(inputs = [X, Y], outputs = cost, updates=updates) \n",
    "        \n",
    "        self.pred = theano.function(inputs = [X], outputs = self.output)                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = LSTM(7, 50, 7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "nb_epochs = 100\n",
    "#stupid and naive sgd\n",
    "for x in range(nb_epochs):\n",
    "    error = 0.\n",
    "    for j in range(len(train_data)):  \n",
    "        index = np.random.randint(0, len(train_data))\n",
    "        i, o = train_data[index]\n",
    "        train_cost = model.train(i, o)\n",
    "        error += train_cost\n",
    "    if x%10==0:\n",
    "            print \"epoch \"+str(x)+ \" error: \"+str(error)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "test_data = get_n_embedded_examples(10)\n",
    "\n",
    "def print_out(test_data):\n",
    "    for i,o in test_data:\n",
    "        p = model.pred(i)\n",
    "        print o[-2] # target\n",
    "        print np.asarray([0. if x!=np.argmax(p[-2]) else 1. for x in range(7)]) # prediction\n",
    "        print \n",
    "print_out(test_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
